# encoding=utf8
import os
import pickle
import logging
import time
import pandas as pd
import numpy as np
import torch
import random
from sklearn.model_selection import StratifiedKFold
from Leaners import Dataset, BertDNN, Bert_BiLSTM, Bert_LstmAtt, Bert_TextCNN, MLP
from Leaners import train_and_eval

from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier, RandomForestClassifier, BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim


def setup_seed(seed):
    """Seed every RNG the pipeline touches so runs are reproducible.

    Seeds PyTorch's CPU generator, the current CUDA device's generator and
    all CUDA devices' generators, then NumPy's global generator and Python's
    ``random`` module. Finally it asks cuDNN to use deterministic kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    # PyTorch: CPU generator, current CUDA device, then every CUDA device.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # NumPy global generator, then the stdlib generator.
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def save_results_to_disk(fold_results, model_name, save_directory):
    """Pickle one model's per-fold results to disk.

    The file is written to ``<save_directory>/<model_name>_fold_results.pkl``
    and overwrites any existing file of that name.

    Args:
        fold_results: Any picklable object (callers pass the list of per-fold
            outputs from training).
        model_name: Prefix used for the output filename.
        save_directory: Target directory; created with parents if missing.
            An empty string means the current working directory.

    Returns:
        The path the results were written to.
    """
    # Filename includes the model name.
    filename = f"{model_name}_fold_results.pkl"
    save_path = os.path.join(save_directory, filename)
    # exist_ok avoids the check-then-create race. Skip "" (the cwd), because
    # os.makedirs("") raises FileNotFoundError.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(fold_results, f)
    print(f"Results saved to {save_path}")
    return save_path


def setup_logging(log_directory):
    """Send root-logger INFO records to ``<log_directory>/training_log.log``.

    Creates ``log_directory`` with parents if it does not exist. An empty
    string means the current working directory.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so only the first configuration in a process takes effect.

    Args:
        log_directory: Directory that will hold ``training_log.log``.

    Returns:
        Path of the log file.
    """
    # exist_ok avoids the check-then-create race; os.makedirs("") would raise.
    if log_directory:
        os.makedirs(log_directory, exist_ok=True)
    log_file = os.path.join(log_directory, 'training_log.log')
    logging.basicConfig(filename=log_file, level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')
    return log_file


def count_trainable_params(model):
    """Return the number of scalar parameters in ``model`` that require gradients.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def build_ensemble_dataset(data_dir='data',
                           base_learners=('Bert_LstmAtt', 'Bert_TextCNN', 'BertDNN', 'Bert_BiLSTM')):
    """Stack saved base-learner outputs into meta-learner feature matrices.

    For each name in ``base_learners``, ``<data_dir>/<name>_fold_results.pkl``
    is loaded. It is indexed as ``data[fold][0 | 1][-1]``:
    - index 0 holds validation outputs and index 1 holds test outputs;
    - ``[-1]`` takes the last entry, presumably the last epoch;
    - that entry is a list of arrays concatenated row-wise.

    * Training features: every fold's validation (out-of-fold) outputs,
      concatenated row-wise in fold order.
    * Test features: the per-fold test outputs averaged over folds.

    The blocks of the individual base learners are concatenated column-wise,
    in the order of ``base_learners``.

    Args:
        data_dir: Directory containing the pickled fold results.
        base_learners: Base-learner names (file prefixes), in column order.

    Returns:
        Tuple ``(ensemble_train, ensemble_test)`` of NumPy arrays.

    Raises:
        ValueError: If ``base_learners`` is empty.
    """
    if not base_learners:
        raise ValueError("base_learners must name at least one base learner")
    train_blocks = []
    test_blocks = []
    for name in base_learners:
        path = os.path.join(data_dir, f'{name}_fold_results.pkl')
        # NOTE: pickle can execute arbitrary code on load; only read files this
        # pipeline wrote itself.
        with open(path, 'rb') as f:
            fold_results = pickle.load(f)
        # [0] = validation outputs, [1] = test outputs, [-1] = last entry.
        val_per_fold = [np.concatenate(fold[0][-1], axis=0) for fold in fold_results]
        test_per_fold = [np.concatenate(fold[1][-1], axis=0) for fold in fold_results]
        # Out-of-fold rows, stacked in fold order.
        train_blocks.append(np.concatenate(val_per_fold, axis=0))
        # Every fold predicts the whole test set; average those predictions.
        test_blocks.append(np.mean(test_per_fold, axis=0))
    return np.concatenate(train_blocks, axis=1), np.concatenate(test_blocks, axis=1)


def run_base_learners(cfg):
    """Train every BERT base learner with stratified K-fold CV and save its outputs.

    Reads ``train.xlsx`` and ``test.xlsx`` (columns ``text`` and ``label``).
    For each base-learner class and each CV fold, it builds a fresh model and
    calls ``train_and_eval`` with:
    - the fold's training loader;
    - ``[fold validation loader, full test loader]``.

    All folds' ``train_and_eval`` outputs are pickled with
    ``save_results_to_disk`` under the class name.

    Args:
        cfg: Configuration dict. Reads ``pre_trained_model``, ``batch_size``,
            ``Base_learner_num`` (number of CV folds), ``seed``, ``lr``,
            ``epochs`` and ``base_learner_out_idr``. Sets ``class_num``.
    """
    # Only pin the CUDA device when CUDA exists; set_device raises otherwise.
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    df_data_train = pd.read_excel('train.xlsx')
    df_data_test = pd.read_excel('test.xlsx')
    train_data, train_labels = df_data_train['text'], df_data_train['label']
    test_data, test_labels = df_data_test['text'], df_data_test['label']
    # Size the output layer from BOTH splits. Models are trained on train
    # labels, so a class absent from test.xlsx still needs an output unit.
    cfg['class_num'] = pd.concat([train_labels, test_labels]).nunique()

    test_dataset = Dataset(texts=test_data, labels=test_labels, pretrained_model=cfg['pre_trained_model'])
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg['batch_size'], shuffle=False)

    # build_ensemble_dataset stacks the validation outputs in this fold order.
    # Meta-learners rebuild the same split (same n_splits, seed and labels) to
    # align the labels.
    skf = StratifiedKFold(n_splits=cfg['Base_learner_num'], shuffle=True, random_state=cfg['seed'])

    for base_learner in (Bert_LstmAtt, Bert_TextCNN, BertDNN, Bert_BiLSTM):
        model_fold_results = []
        for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_data, train_labels)):
            train_subset = Dataset(texts=train_data.iloc[train_idx], labels=train_labels.iloc[train_idx],
                                   pretrained_model=cfg['pre_trained_model'])
            val_subset = Dataset(texts=train_data.iloc[val_idx], labels=train_labels.iloc[val_idx],
                                 pretrained_model=cfg['pre_trained_model'])
            train_dataloader = torch.utils.data.DataLoader(train_subset, batch_size=cfg['batch_size'], shuffle=True)
            val_dataloader = torch.utils.data.DataLoader(val_subset, batch_size=cfg['batch_size'], shuffle=False)

            model = base_learner(cfg)
            start_time = time.perf_counter()  # monotonic clock for durations
            output_info = train_and_eval(model, train_dataloader, [val_dataloader, test_dataloader], cfg['lr'], cfg['epochs'], device)
            elapsed_time = time.perf_counter() - start_time
            trainable_params = count_trainable_params(model)
            logging.info(
                f"Training and evaluation took {elapsed_time:.2f} seconds for model {model.__class__.__name__} (#{trainable_params}) on fold {fold_idx}, {cfg['epochs']} epochs")
            model_fold_results.append(output_info)

        # Files are named after the model class.
        save_results_to_disk(model_fold_results, base_learner.__name__, cfg['base_learner_out_idr'])


def run_ensemble_ml_learners(cfg):
    """Fit scikit-learn meta-learners on stacked base-learner outputs and log test accuracy.

    The training features are the out-of-fold base-learner outputs from
    ``build_ensemble_dataset``. Their labels come from re-running the same
    ``StratifiedKFold`` split over the training labels. AdaBoost, gradient
    boosting, random forest and bagged decision trees are then each fitted
    and scored on ``test.xlsx`` labels.

    Args:
        cfg: Configuration dict. Reads ``Base_learner_num``, ``seed`` and
            ``ensemble_n_estimators``. Sets ``class_num``.
    """
    ensemble_train, ensemble_test = build_ensemble_dataset()
    df_data_train = pd.read_excel('train.xlsx')
    df_data_test = pd.read_excel('test.xlsx')
    train_labels = df_data_train['label']
    test_labels = df_data_test['label']
    # Count classes over both splits so a class missing from test is not lost.
    cfg['class_num'] = pd.concat([train_labels, test_labels]).nunique()

    # Recreate the CV split used to produce the out-of-fold rows (same
    # n_splits, seed and labels). Collecting each fold's validation labels in
    # order lines them up with the rows of ensemble_train.
    skf = StratifiedKFold(n_splits=cfg['Base_learner_num'], shuffle=True, random_state=cfg['seed'])
    ensemble_train_labels = []
    for _, val_idx in skf.split(train_labels, train_labels):
        ensemble_train_labels.extend(train_labels.iloc[val_idx].values.tolist())

    n_estimators = cfg['ensemble_n_estimators']
    seed = cfg['seed']
    tree = DecisionTreeClassifier(random_state=seed)
    try:
        # scikit-learn >= 1.2 names the sub-model `estimator`; the old
        # `base_estimator` keyword was removed in 1.4.
        bagging = BaggingClassifier(estimator=tree, n_estimators=n_estimators, random_state=seed)
    except TypeError:
        # scikit-learn < 1.2 only accepts `base_estimator`.
        bagging = BaggingClassifier(base_estimator=tree, n_estimators=n_estimators, random_state=seed)

    classifiers = [
        AdaBoostClassifier(n_estimators=n_estimators, random_state=seed),
        GradientBoostingClassifier(n_estimators=n_estimators, random_state=seed),
        RandomForestClassifier(n_estimators=n_estimators, random_state=seed),
        bagging,
    ]

    for clf in classifiers:
        start_time = time.perf_counter()
        clf.fit(ensemble_train, ensemble_train_labels)
        training_time = time.perf_counter() - start_time
        y_pred = clf.predict(ensemble_test)
        accuracy = accuracy_score(test_labels, y_pred)
        logging.info(f"{clf.__class__.__name__} Accuracy: {accuracy} Training Time: {training_time:.4f} seconds")


def run_ensemble_MLP_learners(cfg):
    """Train an MLP meta-learner on stacked base-learner outputs and log test accuracy.

    The training features are the out-of-fold base-learner outputs from
    ``build_ensemble_dataset``. Their labels come from re-running the same
    ``StratifiedKFold`` split over the training labels.

    Training is full-batch: each epoch does one optimizer step on the whole
    feature matrix. After each epoch the model is scored on ``test.xlsx``.

    Args:
        cfg: Configuration dict.
            - Reads ``Base_learner_num``, ``seed`` and ``ensemble_hidden_size``.
            - Optionally reads ``ensemble_lr`` (default 0.001) and
              ``ensemble_epochs`` (default 50).
            - Sets ``class_num``.
    """
    # Only pin the CUDA device when CUDA exists; set_device raises otherwise.
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    ensemble_train, ensemble_test = build_ensemble_dataset()
    df_data_train = pd.read_excel('train.xlsx')
    df_data_test = pd.read_excel('test.xlsx')
    train_labels = df_data_train['label']
    test_labels = df_data_test['label']
    # Count classes over both splits so a class missing from test is not lost.
    cfg['class_num'] = pd.concat([train_labels, test_labels]).nunique()

    # Recreate the CV split used to produce the out-of-fold rows so the
    # collected validation labels line up with the rows of ensemble_train.
    skf = StratifiedKFold(n_splits=cfg['Base_learner_num'], shuffle=True, random_state=cfg['seed'])
    ensemble_train_labels = []
    for _, val_idx in skf.split(train_labels, train_labels):
        ensemble_train_labels.extend(train_labels.iloc[val_idx].values.tolist())

    model = MLP(ensemble_train.shape[1], cfg['ensemble_hidden_size'], cfg['class_num']).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=cfg.get('ensemble_lr', 0.001))
    num_epochs = cfg.get('ensemble_epochs', 50)

    x_train = torch.tensor(ensemble_train, dtype=torch.float32, device=device)
    y_train = torch.tensor(ensemble_train_labels, dtype=torch.long, device=device)
    x_test = torch.tensor(ensemble_test, dtype=torch.float32, device=device)
    # Accuracy is computed by scikit-learn, so keep the test labels in NumPy.
    y_test = test_labels.to_numpy()

    start_time = time.perf_counter()
    accuracy = 0.0  # defined even when num_epochs == 0
    bst_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        # Full-batch forward pass over all out-of-fold rows.
        outputs = model(x_train)
        loss = criterion(outputs, y_train)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

        model.eval()
        with torch.no_grad():
            predicted = model(x_test).argmax(dim=1)

        accuracy = accuracy_score(y_test, predicted.cpu().numpy())
        bst_acc = max(bst_acc, accuracy)
        print("Test Accuracy:", accuracy)
    training_time = time.perf_counter() - start_time
    # `accuracy` is the final-epoch score. `bst_acc` is the best per-epoch
    # score; it is chosen by looking at the test set, so it is optimistic.
    logging.info(f"MLP ensemble Accuracy: {accuracy} Best Test Accuracy: {bst_acc} "
                 f"Training Time: {training_time:.4f} seconds")


if __name__ == '__main__':
    # Architecture sizes and meta-learner settings.
    model_hparams = dict(
        pre_trained_model='hfl/chinese-bert-wwm-ext',
        embed_size=768,
        hidden_size_dnn_1=128,
        hidden_size_dnn_2=32,
        dropout_dnn_1=0.1,
        dropout_dnn_2=0.1,
        hidden_size_lstm=32,
        num_layers_lstm=1,
        fc_hidden_size_lstm=128,
        dropout=0.1,
        kernel_sizes=[3, 4, 5],
        num_channels=[100, 100, 100],
        ensemble_n_estimators=50,
        ensemble_hidden_size=128,
    )
    # Seeding, optimisation, CV fold count and I/O locations.
    # NOTE(review): lr=1e-6 is well below the usual BERT fine-tuning range
    # (around 2e-5); confirm this is intended.
    run_settings = dict(
        seed=42,
        epochs=10,
        lr=1e-6,
        batch_size=8,
        Base_learner_num=5,
        base_learner_out_idr='./data',
        log_path='./',
    )

    # One flat config dict; run settings win on key clashes.
    cfg = {**model_hparams, **run_settings}
    setup_seed(cfg['seed'])
    setup_logging(cfg['log_path'])
    # Pipeline stages; uncomment the ones to run:
    #   run_base_learners        -> train base learners, pickle per-fold outputs
    #   run_ensemble_ml_learners -> scikit-learn meta-learners on those outputs
    #   run_ensemble_MLP_learners -> MLP meta-learner on those outputs
    # run_base_learners(cfg)
    # run_ensemble_ml_learners(cfg)
    run_ensemble_MLP_learners(cfg)


